Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Feature visualisation


Image convolution & CNN 101


Kernel & convolution


Kernel: a small matrix used for blurring, sharpening, embossing, edge detection, etc.

Convolution: an operation that calculates the weighted sum of each pixel's local neighbours

Examples of convolution: detecting edges
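As an illustration, here is a minimal NumPy sketch of 2D convolution detecting a vertical edge (the toy image and Sobel kernel are illustrative, not from the talk):

```python
import numpy as np

def convolve2d(image, kernel):
    """Weighted sum of each pixel's local neighbourhood (valid padding)."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Vertical-edge (Sobel) kernel
kernel = np.array([[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]])

# A toy image: dark on the left, bright on the right
image = np.zeros((5, 5))
image[:, 3:] = 1.0

edges = convolve2d(image, kernel)
print(edges)  # strong responses where the dark-to-bright edge sits
```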


Typical CNN architecture


Kernel weights are learnt during training to extract relevant features from input images.
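For instance, in PyTorch a convolutional layer's kernels are ordinary learnable parameters, updated by backprop like any other weight (a sketch, not FlashTorch code):

```python
import torch
import torch.nn as nn

# 16 kernels, each 3 channels x 3 x 3, whose weights are learnt by backprop
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

x = torch.rand(1, 3, 224, 224)  # a batch with one RGB image
features = conv(x)              # feature maps produced by the kernels

print(conv.weight.shape)  # torch.Size([16, 3, 3, 3])
print(features.shape)     # torch.Size([1, 16, 224, 224])
```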

Feature visualisation technique

Saliency maps


Saliency


  • A subjective quality in human visual perception

  • Makes certain items stand out and grab our attention

Saliency maps in CNNs


  • First introduced in 2013

  • Indicate the most “salient” regions of an image

  • Focus on the gradients of the output category (target class) w.r.t. the input image

  • Small changes to pixels with positive gradients will increase the probability of the target class

  • Visualising the gradients provides some intuition of where the model's attention lies
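The core idea can be sketched in a few lines of plain PyTorch, using a tiny stand-in model (hypothetical; FlashTorch wraps this logic in its `Backprop` class):

```python
import torch
import torch.nn as nn

# Tiny stand-in classifier (hypothetical; any CNN works the same way)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
model.eval()

input_ = torch.rand(1, 3, 8, 8, requires_grad=True)
target_class = 4

output = model(input_)
# Backpropagate from the target class score only
output[0, target_class].backward()

# Gradient magnitude per pixel: the saliency map
saliency = input_.grad.abs()
print(saliency.shape)  # torch.Size([1, 3, 8, 8])
```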

FlashTorch demo 1

Visualising saliency maps with backpropagation


Install FlashTorch & load an image



$ pip install flashtorch

...
In [2]:
import matplotlib.pyplot as plt

from flashtorch.utils import load_image

image = load_image('../../examples/images/great_grey_owl_01.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Apply transformations


In [3]:
from flashtorch.utils import apply_transforms, denormalize, format_for_plotting

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Create a Backprop object with a pre-trained model


In [4]:
import torchvision.models as models

from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
Signature:

    backprop.calculate_gradients(input_, target_class=None, take_max=False)

Calculate the gradients of target class w.r.t. input


In [5]:
from flashtorch.utils import ImageNetIndex

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(f'Target class index: {target_class}')

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
Target class index: 24
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's visualise gradients


In [6]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)
Pixels where the animal is present have the strongest positive effects.
But it's quite noisy...

FlashTorch demo 2

Visualising saliency maps with guided backpropagation


Guided backpropagation


  • Additional guidance from the higher layers during backprop

  • Masks out neurons for which at least one of the following is negative:

    • Activation value during the forward pass (pre-ReLU)
    • Gradient during the backward pass
  • Prevents the flow of gradients through neurons that decrease the activation of the layer of interest
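A minimal sketch of the gradient-side mask, implemented as a ReLU whose backward pass clamps negative incoming gradients (FlashTorch itself applies this via hooks on the model's existing ReLU modules):

```python
import torch
import torch.nn as nn

class GuidedReLU(nn.Module):
    """ReLU for guided backprop: the forward pass already masks neurons
    with negative pre-ReLU activations; the hook additionally zeroes
    out negative gradients flowing backwards."""
    def forward(self, x):
        out = torch.relu(x)
        if out.requires_grad:
            out.register_hook(lambda grad: grad.clamp(min=0))
        return out

x = torch.tensor([[-1.0, 2.0]], requires_grad=True)
y = GuidedReLU()(x)
y.backward(torch.tensor([[1.0, -1.0]]))
print(x.grad)  # tensor([[0., 0.]]) -- each condition masks one position
```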

Calculate the gradients with guided backprop


In [7]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)
Now that's much less noisy!
Pixels around the head and eyes have the strongest positive effects.

What about a jay?


In [9]:
target_class = imagenet['jay']

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

Or an oystercatcher?


In [11]:
target_class = imagenet['oystercatcher']

guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)

FlashTorch demo 3

Gaining additional insights on transfer learning


Transfer learning


  • A model developed for one task is reused as a starting point for another task

  • Pre-trained models are often used in computer vision & natural language processing tasks

  • Saves compute & time resources

Building a flower classifier


From: DenseNet model, pre-trained on ImageNet (1,000 classes)

To: Flower classifier recognising 102 species of flowers, using a dataset from the VGG group

Load a target image


In [12]:
image = load_image('../../examples/images/foxglove.jpg')
input_ = apply_transforms(image)

plt.imshow(image)
plt.title('Foxglove')
plt.axis('off');

Pre-trained model - no additional training


In [15]:
backprop = Backprop(pretrained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:93: UserWarning: The predicted class does not equal the target class. Calculating the gradient with respect to the predicted class.

Trained model - test accuracy 98.7%


In [16]:
backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
The trained model pays attention specifically to the distinguishing pattern of this particular species.

Thank you!